{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bd35935c",
   "metadata": {},
   "source": [
    "# MVKM-ED: Advanced Multi-View Clustering Demo\n",
    "\n",
    "This notebook demonstrates the advanced capabilities of our MVKM-ED implementation, including:\n",
    "- GPU acceleration\n",
    "- Advanced visualizations\n",
    "- Performance metrics\n",
    "- Data preprocessing\n",
    "- Model persistence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51170447",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from mvkm_ed import MVKMED, MVKMEDConfig\n",
    "from mvkm_ed.utils import MVKMEDVisualizer, MVKMEDMetrics, MVKMEDDataProcessor\n",
    "from mvkm_ed.utils import MVKMEDPersistence\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "735574d4",
   "metadata": {},
   "source": [
    "## 1. Generate Sample Multi-View Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fc9250b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate synthetic multi-view data\n",
    "np.random.seed(42)  # seed NumPy's global RNG so the draws below are reproducible\n",
    "\n",
    "def generate_multiview_data(n_samples=300, view_dims=(10, 15), n_clusters=3):\n",
    "    \"\"\"Draw a synthetic multi-view dataset whose views share one set of labels.\n",
    "\n",
    "    Cluster centers are sampled once at the widest view's dimensionality.\n",
    "    Each view takes the first `dim` coordinates of every sample's center and\n",
    "    adds Gaussian noise scaled by 0.1.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_samples : int\n",
    "        Number of samples (rows) in every view.\n",
    "    view_dims : sequence of int\n",
    "        Feature dimensionality of each view; one array is produced per entry.\n",
    "    n_clusters : int\n",
    "        Number of cluster centers to draw.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    views : list of np.ndarray\n",
    "        One array of shape (n_samples, dim) per entry of `view_dims`.\n",
    "    labels : np.ndarray\n",
    "        Integer cluster index in [0, n_clusters) for each sample.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `view_dims` is empty.\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    Draws from NumPy's global random state; call `np.random.seed` first\n",
    "    for reproducible output.\n",
    "    \"\"\"\n",
    "    view_dims = list(view_dims)\n",
    "    if not view_dims:\n",
    "        raise ValueError(\"view_dims must contain at least one dimension\")\n",
    "\n",
    "    # Draw order (centers, labels, then per-view noise) determines the data\n",
    "    # produced for a given seed, so it must not be rearranged.\n",
    "    centers = np.random.randn(n_clusters, max(view_dims))\n",
    "    labels = np.random.randint(0, n_clusters, n_samples)\n",
    "\n",
    "    # Look up each sample's center once instead of once per view\n",
    "    sample_centers = centers[labels]\n",
    "    views = [\n",
    "        sample_centers[:, :dim] + 0.1 * np.random.randn(n_samples, dim)\n",
    "        for dim in view_dims\n",
    "    ]\n",
    "\n",
    "    return views, labels\n",
    "\n",
    "X, true_labels = generate_multiview_data()\n",
    "print(f\"Generated {len(X)} views of shapes: {[x.shape for x in X]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8465bdc3",
   "metadata": {},
   "source": [
    "## 2. Preprocess Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2565b99",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass all views through the package's preprocessing helper with both the\n",
    "# scaling and normalization options enabled; the result gets its own name.\n",
    "processor = MVKMEDDataProcessor()\n",
    "X_processed = processor.preprocess_views(\n",
    "    X,\n",
    "    normalize=True,\n",
    "    scale=True,\n",
    ")\n",
    "print(\"Data preprocessing complete\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d517682f",
   "metadata": {},
   "source": [
    "## 3. Configure and Initialize Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e0118f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model settings: points_view comes from the data (len(X), the number of\n",
    "# generated views); every other value is fixed for this demo.\n",
    "config = MVKMEDConfig(\n",
    "    points_view=len(X),\n",
    "    cluster_num=3,\n",
    "    alpha=2.0,\n",
    "    beta=0.1,\n",
    "    device=\"auto\",  # Will use GPU if available\n",
    "    random_state=42,\n",
    "    verbose=True,\n",
    ")\n",
    "\n",
    "# Construct the model from that configuration\n",
    "model = MVKMED(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d30cc983",
   "metadata": {},
   "source": [
    "## 4. Fit Model and Visualize Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f901ade1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train on the preprocessed views\n",
    "model.fit(X_processed)\n",
    "\n",
    "# Wrap the fitted model in the plotting helper\n",
    "viz = MVKMEDVisualizer(model)\n",
    "\n",
    "# One wide figure with two side-by-side panels:\n",
    "# convergence on the left, view weights on the right\n",
    "plt.figure(figsize=(15, 5))\n",
    "for panel_index, draw_panel in enumerate((viz.plot_convergence, viz.plot_view_weights), start=1):\n",
    "    plt.subplot(1, 2, panel_index)\n",
    "    draw_panel()\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26f82238",
   "metadata": {},
   "source": [
    "## 5. Evaluate Clustering Quality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bed46a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute metrics from the preprocessed views, the fitted labels and the true labels\n",
    "predicted_labels = model.labels_\n",
    "metrics = MVKMEDMetrics.compute_metrics(X_processed, predicted_labels, true_labels)\n",
    "\n",
    "# Report one metric per line, formatted to four decimal places\n",
    "print(\"\\nClustering Metrics:\")\n",
    "for metric, value in metrics.items():\n",
    "    print(f\"{metric}: {value:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a63810c",
   "metadata": {},
   "source": [
    "## 6. Save and Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a25247b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Round-trip the fitted model through disk: save it, then load it back.\n",
    "# MVKMEDPersistence is imported in the notebook's import cell.\n",
    "MODEL_PATH = 'mvkm_ed_model.joblib'  # relative path, shared by save and load\n",
    "\n",
    "# Save model\n",
    "MVKMEDPersistence.save_model(model, MODEL_PATH)\n",
    "\n",
    "# Load model\n",
    "loaded_model = MVKMEDPersistence.load_model(MODEL_PATH)\n",
    "print(\"Model successfully saved and loaded\")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
